#define ARMA_DONT_PRINT_ERRORS

#include <RcppArmadillo.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <cmath>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppGSL)]]

#include <Ziggurat.h>
// [[Rcpp::depends(RcppZiggurat)]]

// File-scope Ziggurat generator. truncn() and get1TN() below obtain their
// standard normal variates from it through zigg.norm(). There is a single
// instance for the whole translation unit, so every caller advances the same
// generator state.
static Ziggurat::Ziggurat::Ziggurat zigg;




//[[Rcpp:plugins(cpp11)]]
#include <omp.h>
//[[Rcpp:plugins(openmp)]]


// log(2 * pi), evaluated once during static initialisation.
// Not referenced by any code visible in this part of the file.
const double log2pi = std::log(2.0 * M_PI);



// Draws one value from a normal distribution N(mu, sigma^2) truncated on one
// side at `bound`, using Geweke's (1991) scheme.
//
//   bound  truncation point
//   lb     true  -> `bound` is a lower bound
//          false -> `bound` is an upper bound
//   mu     mean of the untruncated normal
//   sigma  standard deviation of the untruncated normal
//
// The upper-bound case is handled by reflection: the cut-off is negated before
// sampling and the standardised draw is subtracted from `mu` at the end.
// Standard normal variates come from the file-level `zigg` generator and
// uniform variates from R's ::Rf_runif.
double truncn(double bound, bool lb, double mu, double sigma){

  // 1. Standardised cut-off; the sampler below always produces z >= c.
  const double c = lb ? (bound - mu) / sigma
                      : -(bound - mu) / sigma;

  // 2. Standardised draw.
  double z;
  if (c < 0.45) {
    // Cut-off not far into the right tail: redraw a standard normal until it
    // lands at or above the cut-off.
    do {
      z = zigg.norm();
    } while (z < c);
  } else {
    // Cut-off in the right tail: propose an exponential variate with rate c
    // and accept it with probability exp(-z^2 / 2); the accepted variate is
    // then shifted by c.
    double w;
    do {
      z = -log(1 - ::Rf_runif(0.0, 1.0)) / c;
      w = ::Rf_runif(0.0, 1.0);
    } while (w > exp(-0.5 * pow(z, 2)));
    z = z + c;
  }

  // 3. Undo the standardisation (and the reflection for an upper bound).
  return lb ? mu + sigma * z
            : mu - sigma * z;
}



 // Draws one value from N(mu, sd^2) restricted to the closed interval
 // [low, high] by naive rejection: candidates mu + z * sd, with z taken from
 // the file-level `zigg` generator, are redrawn until one falls inside the
 // interval.
 //
 //   mu    mean of the untruncated normal
 //   sd    standard deviation of the untruncated normal
 //   low   lower bound (may be -INFINITY)
 //   high  upper bound (may be +INFINITY)
 //
 // Returns the accepted draw. Inputs for which no candidate could ever be
 // accepted (non-finite mu or sd, NaN bounds, low > high, or sd == 0 with mu
 // outside the interval) return NaN instead of looping forever. A degenerate
 // interval low == high returns that single point.
 //
 // The expected number of candidates per draw is 1 / P(low <= X <= high) for
 // X ~ N(mu, sd^2), so this becomes slow when the interval sits far in a tail.
 double get1TN(double mu,
               double sd,
               double low,
               double high) {
   // Parameters that make every candidate NaN/inf, or an empty interval.
   if (!std::isfinite(mu) || !std::isfinite(sd) || !(low <= high)) {
     return NAN;
   }

   // Single-point interval: a continuous candidate would practically never
   // hit it exactly.
   if (low == high) {
     return std::isfinite(low) ? low : NAN;
   }

   // Zero spread: every candidate equals mu, so decide once.
   if (sd == 0.0) {
     return (mu >= low && mu <= high) ? mu : NAN;
   }

   double cand;
   do {
     cand = mu + zigg.norm() * sd;
   } while (!(cand >= low && cand <= high));

   return cand;
 }


// Squared Euclidean norm (sum of squared elements) of a vector passed as a
// matrix, in either orientation.
//
// A column vector gives rhs' * rhs as before. A 1 x k row vector is handled
// explicitly because rhs.t() * rhs would be the k x k outer product, whose
// first element is only rhs(0)^2 rather than the sum of squares.
//
// For a matrix with more than one row and column the result is the squared
// norm of its first column, i.e. element (0,0) of rhs' * rhs.
double vm_mult(const arma::mat& rhs)
{
  if (rhs.n_rows == 1) {
    return arma::dot(rhs, rhs);
  }
  return arma::dot(rhs.col(0), rhs.col(0));
}

// Returns the element at linear index 0 of `rhs`. mcmc_probit() uses it to
// turn a 1 x 1 matrix product into a plain double.
double vm_convert(const arma::mat& rhs)
{
  return rhs(0);
}

// Returns the first element of the transpose of `rhs`. Element (0,0) is
// unchanged by transposition, so it is read from `rhs` directly.
double am_mult2(const arma::mat& rhs)
{
  return rhs(0);
}





// Gibbs sampler for a probit item-response model with country-specific item
// parameters and two latent respondent traits (C_conv and C_unconv).
//
// Model as implemented below, for respondent n and item l:
//   mu(n,l)    =  I(n,.) * alpha(l,.)'
//               + (I(n,.) * alpha_conv(l,.)')   * C_conv(n)
//               + (I(n,.) * alpha_unconv(l,.)') * C_unconv(n)
//   ystar(n,l) ~  N(mu(n,l), 1), truncated to >= 0 where Y(n,l) == 1 and to
//                 <= 0 where Y(n,l) == 0; any other value of Y(n,l) leaves
//                 ystar(n,l) at its previous value.
//
// Arguments
//   Y                  N x nitems response matrix.
//   I                  N x ncountry design matrix; presumably 0/1 country
//                      membership indicators -- TODO confirm with caller.
//   tablecountry       length-ncountry vector added to the prior precision of
//                      the intercepts; presumably respondents per country --
//                      TODO confirm with caller.
//   Cinit_conv,
//   Cinit_unconv       length-N starting values of the latent traits.
//   alphastart,
//   alphastart_conv,
//   alphastart_unconv  nitems x ncountry starting values of the item parameters.
//   mcmc               total iterations per chain (including burn-in).
//   burn               iterations discarded at the start, 0 <= burn <= mcmc.
//   thin, thin_latent  keep every thin-th (item parameters) / thin_latent-th
//                      (latent traits) iteration after burn-in; both >= 1.
//   chains             number of chains; each starts from the values above.
//
// Returns a list of cubes whose third dimension indexes the chain:
//   alpha, alpha_conv, alpha_unconv            (nitems*ncountry) x kept x chains
//   sigma2_alphaintercept, sigma2_alphaconv,
//   sigma2_alphaunconv                         nitems x kept x chains
//   C_conv, C_unconv                           N x kept_latent x chains
// with kept = (mcmc - burn) / thin and kept_latent = (mcmc - burn) / thin_latent.
//
// Invalid mcmc / burn / thin / thin_latent / chains raise an R error through
// Rcpp::stop().
//
// Threading: chains are the iterations of an OpenMP parallel-for. All sampler
// state is local to a chain. Draws that go through truncn()/get1TN() use the
// file-level `zigg` generator and R's RNG, which are shared, so those calls are
// serialised in a named critical section.
// NOTE(review): the `//[[Rcpp:plugins(openmp)]]` marker near the top of the
// file is written with a single colon and so does not look like a valid Rcpp
// attribute; confirm that the build actually enables OpenMP.
// [[Rcpp::export()]]
Rcpp::List mcmc_probit (arma::mat Y,
                        arma::mat I,
                        arma::vec tablecountry,
                        arma::colvec Cinit_conv,
                        arma::colvec Cinit_unconv,
                        arma::mat alphastart,
                        arma::mat alphastart_conv,
                        arma::mat alphastart_unconv,
                        int mcmc = 100,
                        int burn=10,
                        int thin=10,
                        int thin_latent=10,
                        int chains=2
                        ) {

  // Reject settings that would divide by zero or give negative trace sizes.
  if (thin < 1 || thin_latent < 1) {
    Rcpp::stop("mcmc_probit: 'thin' and 'thin_latent' must be >= 1");
  }
  if (burn < 0 || mcmc < burn) {
    Rcpp::stop("mcmc_probit: need 0 <= 'burn' <= 'mcmc'");
  }
  if (chains < 1) {
    Rcpp::stop("mcmc_probit: 'chains' must be >= 1");
  }

  // Dimensions
  const int N = static_cast<int>(Y.n_rows);
  const int nitems = static_cast<int>(Y.n_cols);
  const int ncountry = static_cast<int>(I.n_cols);

  // Number of retained draws per chain
  const int lastit = (mcmc - burn) / thin;
  const int lastit_latent = (mcmc - burn) / thin_latent;

  // Trace containers (shared; each chain writes only to its own slice)
  arma::cube tracealpha(nitems*ncountry, lastit, chains);
  arma::cube tracealphaconv(nitems*ncountry, lastit, chains);
  arma::cube tracealphaunconv(nitems*ncountry, lastit, chains);
  arma::cube tracesigma2alpha(nitems, lastit, chains);
  arma::cube tracesigma2alphaconv(nitems, lastit, chains);
  arma::cube tracesigma2alphaunconv(nitems, lastit, chains);
  arma::cube traceCconv(N, lastit_latent, chains);
  arma::cube traceCunconv(N, lastit_latent, chains);

  // Shape of the inverse-gamma full conditionals of the variance parameters.
  // Floating-point division: integer division would drop the half for an odd
  // number of countries.
  const double ig_shape = 1.0 + ncountry / 2.0;

  // One seed per chain, drawn from R's RNG on the calling thread so that the
  // chains use distinct GSL streams and the run follows R's set.seed().
  // gsl_rng_alloc() alone would give every chain the same default seed.
  arma::vec chain_seed(chains);
  for (int chain = 0; chain < chains; ++chain) {
    chain_seed(chain) = ::Rf_runif(1.0, 4294967295.0);
  }

  omp_set_num_threads(chains);

#pragma omp parallel for
  for (int chain = 0; chain < chains; ++chain) {

    gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937);
    gsl_rng_set(s, static_cast<unsigned long>(chain_seed(chain)));

    // Chain-local sampler state, initialised from the supplied starting values.
    arma::mat index_varCconv(N, nitems);
    arma::mat index_meanCconv(N, nitems);
    arma::mat index_varCunconv(N, nitems);
    arma::mat index_meanCunconv(N, nitems);
    arma::mat ystar(N, nitems, arma::fill::zeros);
    arma::mat mu(N, nitems);
    arma::colvec C_conv = Cinit_conv;
    arma::colvec C_unconv = Cinit_unconv;
    arma::mat alpha = alphastart;
    arma::mat alpha_conv = alphastart_conv;
    arma::mat alpha_unconv = alphastart_unconv;
    arma::vec sigma2_alphaintercept(nitems, arma::fill::ones);
    arma::vec sigma2_alphaconv(nitems, arma::fill::ones);
    arma::vec sigma2_alphaunconv(nitems, arma::fill::ones);

    for (int iter = 0; iter < mcmc; ++iter) {

      // ---- Update C_conv (normal full conditional, prior precision 1)
      index_varCconv = I * alpha_conv.t();
      for (int l = 0; l < nitems; ++l) {
        index_meanCconv.col(l) = (I*alpha_conv.row(l).t())
          % (ystar.col(l) - I*alpha.row(l).t() - (I*alpha_unconv.row(l).t()) % (C_unconv));
      }
      for (int n = 0; n < N; ++n) {
        const double vCconv = 1 / (arma::dot(index_varCconv.row(n), index_varCconv.row(n)) + 1);
        const double bCconv = vCconv * arma::accu(index_meanCconv.row(n));
        C_conv(n) = bCconv + gsl_ran_gaussian_ziggurat(s, std::sqrt(vCconv));
      }

      // ---- Update C_unconv (same form, using the freshly updated C_conv)
      index_varCunconv = I * alpha_unconv.t();
      for (int l = 0; l < nitems; ++l) {
        index_meanCunconv.col(l) = (I*alpha_unconv.row(l).t())
          % (ystar.col(l) - I*alpha.row(l).t() - (I*alpha_conv.row(l).t()) % (C_conv));
      }
      for (int n = 0; n < N; ++n) {
        const double vCunconv = 1 / (arma::dot(index_varCunconv.row(n), index_varCunconv.row(n)) + 1);
        const double bCunconv = vCunconv * arma::accu(index_meanCunconv.row(n));
        C_unconv(n) = bCunconv + gsl_ran_gaussian_ziggurat(s, std::sqrt(vCunconv));
      }

      // ---- Update latent responses, item intercepts and their variances
      for (int l = 0; l < nitems; ++l) {

        mu.col(l) = I*alpha.row(l).t()
          + (I*alpha_conv.row(l).t()) % (C_conv)
          + (I*alpha_unconv.row(l).t()) % (C_unconv);

        // get1TN() draws from the shared file-level generator, so only one
        // chain at a time may run this loop.
#pragma omp critical(mcmc_probit_shared_rng)
        {
          for (int n = 0; n < N; ++n) {
            if (Y(n, l) == 1) {
              ystar(n, l) = get1TN(mu(n, l), 1, 0, INFINITY);
            } else if (Y(n, l) == 0) {
              ystar(n, l) = get1TN(mu(n, l), 1, -INFINITY, 0);
            }
          }
        }

        const arma::vec llreffectalpha = I.t()
          * (ystar.col(l) - (I*alpha_conv.row(l).t()) % (C_conv) - (I*alpha_unconv.row(l).t()) % (C_unconv));

        for (int c = 0; c < ncountry; ++c) {
          const double valpha = 1 / (tablecountry(c) + (1 / sigma2_alphaintercept(l)));
          const double balpha = valpha * llreffectalpha(c);
          alpha(l, c) = balpha + gsl_ran_gaussian_ziggurat(s, std::sqrt(valpha));
        }

        // Inverse-gamma draw: scale uses the sum of squares over all countries.
        sigma2_alphaintercept(l) = 1.0 / gsl_ran_gamma(
          s, ig_shape, 1.0 / (1.0 + arma::accu(arma::square(alpha.row(l))) / 2));
      }

      // ---- Update alpha_conv (truncated to >= 0).
      // The last item is skipped, so its row keeps the starting values --
      // presumably an identification constraint; TODO confirm.
      // NOTE(review): these updates use (C - 1) where the linear predictor mu
      // uses C; confirm that the shift is intended.
      for (int l = 0; l < (nitems - 1); ++l) {
        // Residual does not depend on the country index; compute it once.
        const arma::vec resid = ystar.col(l) - I*alpha.row(l).t()
          - (I*alpha_unconv.row(l).t()) % (C_unconv - 1);

        for (int c = 0; c < ncountry; ++c) {
          const arma::vec xc = I.col(c) % (C_conv - 1);
          const double valpha_conv = 1 / (vm_mult(xc) + (1 / sigma2_alphaconv(l)));
          const double balpha_conv = valpha_conv * vm_convert(xc.t() * (I.col(c) % resid));
          // truncn() uses the shared generator and R's RNG: serialise.
#pragma omp critical(mcmc_probit_shared_rng)
          alpha_conv(l, c) = truncn(0.0, true, balpha_conv, std::sqrt(valpha_conv));
        }

        sigma2_alphaconv(l) = 1.0 / gsl_ran_gamma(
          s, ig_shape, 1.0 / (1.0 + arma::accu(arma::square(alpha_conv.row(l))) / 2));
      }

      // ---- Update alpha_unconv (truncated to >= 0).
      // The first item is skipped, so its row keeps the starting values --
      // presumably an identification constraint; TODO confirm.
      for (int l = 1; l < nitems; ++l) {
        const arma::vec resid = ystar.col(l) - I*alpha.row(l).t()
          - (I*alpha_conv.row(l).t()) % (C_conv - 1);

        for (int c = 0; c < ncountry; ++c) {
          const arma::vec xc = I.col(c) % (C_unconv - 1);
          const double valpha_unconv = 1 / (vm_mult(xc) + (1 / sigma2_alphaunconv(l)));
          const double balpha_unconv = valpha_unconv * vm_convert(xc.t() * (I.col(c) % resid));
#pragma omp critical(mcmc_probit_shared_rng)
          alpha_unconv(l, c) = truncn(0.0, true, balpha_unconv, std::sqrt(valpha_unconv));
        }

        sigma2_alphaunconv(l) = 1.0 / gsl_ran_gamma(
          s, ig_shape, 1.0 / (1.0 + arma::accu(arma::square(alpha_unconv.row(l))) / 2));
      }

      // ---- Trace
      // Thinning is counted from the end of burn-in, so exactly
      // (mcmc - burn) / thin draws are stored and the column index can never
      // run past the trace cube, whatever the relation between burn and thin.
      const int kept = iter + 1 - burn;   // 1-based post-burn-in iteration

      if (kept > 0 && kept % thin == 0) {
        const int j = kept / thin - 1;
        tracealpha.subcube(0, j, chain, nitems*ncountry-1, j, chain) = arma::vectorise(alpha);
        tracealphaconv.subcube(0, j, chain, nitems*ncountry-1, j, chain) = arma::vectorise(alpha_conv);
        tracealphaunconv.subcube(0, j, chain, nitems*ncountry-1, j, chain) = arma::vectorise(alpha_unconv);
        tracesigma2alpha.subcube(0, j, chain, nitems-1, j, chain) = sigma2_alphaintercept;
        tracesigma2alphaconv.subcube(0, j, chain, nitems-1, j, chain) = sigma2_alphaconv;
        tracesigma2alphaunconv.subcube(0, j, chain, nitems-1, j, chain) = sigma2_alphaunconv;
      }

      if (kept > 0 && kept % thin_latent == 0) {
        const int k = kept / thin_latent - 1;
        traceCconv.subcube(0, k, chain, N-1, k, chain) = arma::vectorise(C_conv);
        traceCunconv.subcube(0, k, chain, N-1, k, chain) = arma::vectorise(C_unconv);
      }
    }

    gsl_rng_free(s);
  }

  // Returns
  Rcpp::List ret;

  ret["alpha"] = tracealpha;
  ret["alpha_conv"] = tracealphaconv;
  ret["alpha_unconv"] = tracealphaunconv;
  ret["sigma2_alphaintercept"] = tracesigma2alpha;
  ret["sigma2_alphaconv"] = tracesigma2alphaconv;
  ret["sigma2_alphaunconv"] = tracesigma2alphaunconv;
  ret["C_conv"] = traceCconv;
  ret["C_unconv"] = traceCunconv;

  return(ret);
}

